{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import argparse\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self, mnist=True):\n",
    "      \n",
    "        super(Net, self).__init__()\n",
    "        if mnist:\n",
    "          num_channels = 1\n",
    "        else:\n",
    "          num_channels = 3\n",
    "          \n",
    "        self.conv1 = nn.Conv2d(num_channels, 20, 5, 1)\n",
    "        self.conv2 = nn.Conv2d(20, 50, 5, 1)\n",
    "        self.fc1 = nn.Linear(4*4*50, 500)\n",
    "        self.fc2 = nn.Linear(500, 10)\n",
    "        \n",
    "      \n",
    "    def forward(self, x):\n",
    "        \n",
    "        x = F.relu(self.conv1(x))\n",
    "        x = F.max_pool2d(x, 2, 2)\n",
    "        x = F.relu(self.conv2(x))\n",
    "        x = F.max_pool2d(x, 2, 2)  \n",
    "        x = x.view(-1, 4*4*50)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return F.log_softmax(x, dim=1)"
   ]
  },
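  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A note on the `4*4*50` in `fc1`: with 5x5 kernels and no padding, a 28x28 MNIST image shrinks to 24x24 after `conv1`, to 12x12 after the first 2x2 max-pool, to 8x8 after `conv2`, and to 4x4 after the second pool, leaving 50 channels of 4x4 activations to flatten."
   ]
  },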
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.300039\n",
      "Train Epoch: 1 [32000/60000 (53%)]\tLoss: 0.341256\n",
      "\n",
      "Test set: Average loss: 0.1015, Accuracy: 9664/10000 (97%)\n",
      "\n",
      "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.146154\n",
      "Train Epoch: 2 [32000/60000 (53%)]\tLoss: 0.033408\n",
      "\n",
      "Test set: Average loss: 0.0609, Accuracy: 9827/10000 (98%)\n",
      "\n",
      "Train Epoch: 3 [0/60000 (0%)]\tLoss: 0.052716\n",
      "Train Epoch: 3 [32000/60000 (53%)]\tLoss: 0.028806\n",
      "\n",
      "Test set: Average loss: 0.0559, Accuracy: 9810/10000 (98%)\n",
      "\n",
      "Train Epoch: 4 [0/60000 (0%)]\tLoss: 0.019357\n",
      "Train Epoch: 4 [32000/60000 (53%)]\tLoss: 0.014262\n",
      "\n",
      "Test set: Average loss: 0.0410, Accuracy: 9859/10000 (99%)\n",
      "\n",
      "Train Epoch: 5 [0/60000 (0%)]\tLoss: 0.010769\n",
      "Train Epoch: 5 [32000/60000 (53%)]\tLoss: 0.040200\n",
      "\n",
      "Test set: Average loss: 0.0391, Accuracy: 9870/10000 (99%)\n",
      "\n",
      "Train Epoch: 6 [0/60000 (0%)]\tLoss: 0.128779\n",
      "Train Epoch: 6 [32000/60000 (53%)]\tLoss: 0.013484\n",
      "\n",
      "Test set: Average loss: 0.0343, Accuracy: 9890/10000 (99%)\n",
      "\n",
      "Train Epoch: 7 [0/60000 (0%)]\tLoss: 0.029768\n",
      "Train Epoch: 7 [32000/60000 (53%)]\tLoss: 0.015535\n",
      "\n",
      "Test set: Average loss: 0.0345, Accuracy: 9878/10000 (99%)\n",
      "\n",
      "Train Epoch: 8 [0/60000 (0%)]\tLoss: 0.004699\n",
      "Train Epoch: 8 [32000/60000 (53%)]\tLoss: 0.013942\n",
      "\n",
      "Test set: Average loss: 0.0397, Accuracy: 9874/10000 (99%)\n",
      "\n",
      "Train Epoch: 9 [0/60000 (0%)]\tLoss: 0.088585\n",
      "Train Epoch: 9 [32000/60000 (53%)]\tLoss: 0.006884\n",
      "\n",
      "Test set: Average loss: 0.0292, Accuracy: 9909/10000 (99%)\n",
      "\n",
      "Train Epoch: 10 [0/60000 (0%)]\tLoss: 0.122119\n",
      "Train Epoch: 10 [32000/60000 (53%)]\tLoss: 0.010755\n",
      "\n",
      "Test set: Average loss: 0.0313, Accuracy: 9895/10000 (99%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def train(args, model, device, train_loader, optimizer, epoch):\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = F.nll_loss(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "   \n",
    "        if batch_idx % args[\"log_interval\"] == 0:\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch, batch_idx * len(data), len(train_loader.dataset),\n",
    "                100. * batch_idx / len(train_loader), loss.item()))\n",
    "\n",
    "def test(args, model, device, test_loader):\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss\n",
    "            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "\n",
    "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, correct, len(test_loader.dataset),\n",
    "        100. * correct / len(test_loader.dataset)))\n",
    "\n",
    "def main():\n",
    " \n",
    "    batch_size = 64\n",
    "    test_batch_size = 64\n",
    "    epochs = 10\n",
    "    lr = 0.01\n",
    "    momentum = 0.5\n",
    "    seed = 1\n",
    "    log_interval = 500\n",
    "    save_model = False\n",
    "    no_cuda = False\n",
    "    \n",
    "    use_cuda = not no_cuda and torch.cuda.is_available()\n",
    "\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "\n",
    "    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST('../data', train=True, download=True,\n",
    "                       transform=transforms.Compose([\n",
    "                           transforms.ToTensor(),\n",
    "                           transforms.Normalize((0.1307,), (0.3081,))\n",
    "                       ])),\n",
    "        batch_size=batch_size, shuffle=True, **kwargs)\n",
    "    \n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST('../data', train=False, transform=transforms.Compose([\n",
    "                           transforms.ToTensor(),\n",
    "                           transforms.Normalize((0.1307,), (0.3081,))\n",
    "                       ])),\n",
    "        batch_size=test_batch_size, shuffle=True, **kwargs)\n",
    "    \n",
    "  \n",
    "    model = Net().to(device)\n",
    "    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)\n",
    "    args = {}\n",
    "    args[\"log_interval\"] = log_interval\n",
    "    for epoch in range(1, epochs + 1):\n",
    "        train(args, model, device, train_loader, optimizer, epoch)\n",
    "        test(args, model, device, test_loader)\n",
    "\n",
    "    if (save_model):\n",
    "        torch.save(model.state_dict(),\"mnist_cnn.pt\")\n",
    "    \n",
    "    return model\n",
    "\n",
    "model = main()"
   ]
  },
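  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the float model trained to ~99% accuracy, the rest of the notebook implements post-training quantisation. Each tensor is mapped to unsigned 8-bit integers with an affine scheme $q = \\mathrm{round}(x / s + z)$, clamped to $[0, 255]$, where the scale $s$ and zero point $z$ are chosen so that the tensor's $[\\mathrm{min}, \\mathrm{max}]$ range spans the integer range; dequantisation recovers $x \\approx s\\,(q - z)$."
   ]
  },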
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import namedtuple\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "QTensor = namedtuple('QTensor', ['tensor', 'scale', 'zero_point'])\n",
    "\n",
    "def calcScaleZeroPoint(min_val, max_val,num_bits=8):\n",
    "  # Calc Scale and zero point of next \n",
    "  qmin = 0.\n",
    "  qmax = 2.**num_bits - 1.\n",
    "\n",
    "  scale = (max_val - min_val) / (qmax - qmin)\n",
    "\n",
    "  initial_zero_point = qmin - min_val / scale\n",
    "  \n",
    "  zero_point = 0\n",
    "  if initial_zero_point < qmin:\n",
    "      zero_point = qmin\n",
    "  elif initial_zero_point > qmax:\n",
    "      zero_point = qmax\n",
    "  else:\n",
    "      zero_point = initial_zero_point\n",
    "\n",
    "  zero_point = int(zero_point)\n",
    "\n",
    "  return scale, zero_point\n",
    "\n",
    "def quantize_tensor(x, num_bits=8, min_val=None, max_val=None):\n",
    "    \n",
    "    if not min_val and not max_val: \n",
    "      min_val, max_val = x.min(), x.max()\n",
    "\n",
    "    qmin = 0.\n",
    "    qmax = 2.**num_bits - 1.\n",
    "\n",
    "    scale, zero_point = calcScaleZeroPoint(min_val, max_val, num_bits)\n",
    "    q_x = zero_point + x / scale\n",
    "    q_x.clamp_(qmin, qmax).round_()\n",
    "    q_x = q_x.round().byte()\n",
    "    \n",
    "    return QTensor(tensor=q_x, scale=scale, zero_point=zero_point)\n",
    "\n",
    "def dequantize_tensor(q_x):\n",
    "    return q_x.scale * (q_x.tensor.float() - q_x.zero_point)"
   ]
  },
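  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of `quantize_tensor` and `dequantize_tensor`, a round trip should approximately recover the input: the reconstruction error of the affine scheme is at most about half a quantisation step, `scale / 2`. A minimal check on a random tensor:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Round-trip check: |x - dequantize(quantize(x))| should be at most about scale / 2\n",
    "x = torch.randn(4, 4)\n",
    "qx = quantize_tensor(x, num_bits=8)\n",
    "x_hat = dequantize_tensor(qx)\n",
    "print('max abs error:', (x - x_hat).abs().max().item())\n",
    "print('half a quantisation step:', qx.scale.item() / 2)"
   ]
  },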
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def quantizeLayer(x, layer, stat, scale_x, zp_x):\n",
    "  # for both conv and linear layers\n",
    "\n",
    "  # cache old values\n",
    "  W = layer.weight.data\n",
    "  B = layer.bias.data\n",
    "\n",
    "  # quantise weights, activations are already quantised\n",
    "  w = quantize_tensor(layer.weight.data) \n",
    "  b = quantize_tensor(layer.bias.data)\n",
    "\n",
    "  layer.weight.data = w.tensor.float()\n",
    "  layer.bias.data = b.tensor.float()\n",
    "\n",
    "  # This is Quantisation Artihmetic\n",
    "  scale_w = w.scale\n",
    "  zp_w = w.zero_point\n",
    "  scale_b = b.scale\n",
    "  zp_b = b.zero_point\n",
    "  \n",
    "  scale_next, zero_point_next = calcScaleZeroPoint(min_val=stat['min'], max_val=stat['max'])\n",
    "\n",
    "  # Preparing input by shifting\n",
    "  X = x.float() - zp_x\n",
    "  layer.weight.data = scale_x * scale_w*(layer.weight.data - zp_w)\n",
    "  layer.bias.data = scale_b*(layer.bias.data + zp_b)\n",
    "\n",
    "  # All int computation\n",
    "  x = (layer(X)/ scale_next) + zero_point_next \n",
    "  \n",
    "  # Perform relu too\n",
    "  x = F.relu(x)\n",
    "\n",
    "  # Reset weights for next forward pass\n",
    "  layer.weight.data = W\n",
    "  layer.bias.data = B\n",
    "  \n",
    "  return x, scale_next, zero_point_next\n"
   ]
  },
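  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the arithmetic explicit: with the affine scheme $x \\approx s_x(q_x - z_x)$, and the weights and bias quantised the same way, a layer $y = Wx + b$ whose output range has parameters $s_{next}, z_{next}$ satisfies\n",
    "\n",
    "$$q_y \\approx z_{next} + \\frac{s_x s_W (q_W - z_W)(q_x - z_x) + s_b (q_b - z_b)}{s_{next}}.$$\n",
    "\n",
    "`quantizeLayer` realises exactly this: it shifts the input by `zp_x`, folds $s_x s_W$ into the dequantised weights and $s_b$ into the bias, runs the ordinary float layer, then divides by `scale_next` and adds `zero_point_next`. A true integer-only backend would instead accumulate in int32 and fold the rescaling into a fixed-point multiply; here everything stays in float to simulate the effect."
   ]
  },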
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get Min and max of x tensor, and stores it\n",
    "def updateStats(x, stats, key):\n",
    "  max_val, _ = torch.max(x, dim=1)\n",
    "  min_val, _ = torch.min(x, dim=1)\n",
    "  \n",
    "  \n",
    "  if key not in stats:\n",
    "    stats[key] = {\"max\": max_val.sum(), \"min\": min_val.sum(), \"total\": 1}\n",
    "  else:\n",
    "    stats[key]['max'] += max_val.sum().item()\n",
    "    stats[key]['min'] += min_val.sum().item()\n",
    "    stats[key]['total'] += 1\n",
    "  \n",
    "  return stats\n",
    "\n",
    "# Reworked Forward Pass to access activation Stats through updateStats function\n",
    "def gatherActivationStats(model, x, stats):\n",
    "\n",
    "  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv1')\n",
    "  \n",
    "  x = F.relu(model.conv1(x))\n",
    "\n",
    "  x = F.max_pool2d(x, 2, 2)\n",
    "  \n",
    "  stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv2')\n",
    "  \n",
    "  x = F.relu(model.conv2(x))\n",
    "\n",
    "  x = F.max_pool2d(x, 2, 2)\n",
    "\n",
    "  x = x.view(-1, 4*4*50)\n",
    "  \n",
    "  stats = updateStats(x, stats, 'fc1')\n",
    "\n",
    "  x = F.relu(model.fc1(x))\n",
    "  \n",
    "  stats = updateStats(x, stats, 'fc2')\n",
    "\n",
    "  x = model.fc2(x)\n",
    "\n",
    "  return stats\n",
    "\n",
    "# Entry function to get stats of all functions.\n",
    "def gatherStats(model, test_loader):\n",
    "    device = 'cuda'\n",
    "    \n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    stats = {}\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            stats = gatherActivationStats(model, data, stats)\n",
    "    \n",
    "    final_stats = {}\n",
    "    for key, value in stats.items():\n",
    "      final_stats[key] = { \"max\" : value[\"max\"] / value[\"total\"], \"min\" : value[\"min\"] / value[\"total\"] }\n",
    "    return final_stats"
   ]
  },
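  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One quirk worth flagging: `updateStats` sums the per-sample extrema over each batch, and `gatherStats` divides by the number of batches, so the recorded \"min\"/\"max\" are average per-batch sums, roughly batch-size times larger than a typical per-sample extremum. That is why the printed `conv1` max below is about 180 even though a normalised MNIST pixel never exceeds about 2.8 (2.8 x 64 = 179). The resulting ranges are much looser than necessary, but they are used consistently, and the quantised model below still reaches 99%."
   ]
  },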
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def quantForward(model, x, stats):\n",
    "  \n",
    "  # Quantise before inputting into incoming layers\n",
    "  x = quantize_tensor(x, min_val=stats['conv1']['min'], max_val=stats['conv1']['max'])\n",
    "\n",
    "  x, scale_next, zero_point_next = quantizeLayer(x.tensor, model.conv1, stats['conv2'], x.scale, x.zero_point)\n",
    "\n",
    "  x = F.max_pool2d(x, 2, 2)\n",
    "  \n",
    "  x, scale_next, zero_point_next = quantizeLayer(x, model.conv2, stats['fc1'], scale_next, zero_point_next)\n",
    "\n",
    "  x = F.max_pool2d(x, 2, 2)\n",
    "\n",
    "  x = x.view(-1, 4*4*50)\n",
    "\n",
    "  x, scale_next, zero_point_next = quantizeLayer(x, model.fc1, stats['fc2'], scale_next, zero_point_next)\n",
    "  \n",
    "  # Back to dequant for final layer\n",
    "  x = dequantize_tensor(QTensor(tensor=x, scale=scale_next, zero_point=zero_point_next))\n",
    "   \n",
    "  x = model.fc2(x)\n",
    "\n",
    "  return F.log_softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def testQuant(model, test_loader, quant=False, stats=None):\n",
    "    device = 'cuda'\n",
    "    \n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            if quant:\n",
    "              output = quantForward(model, data, stats)\n",
    "            else:\n",
    "              output = model(data)\n",
    "            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss\n",
    "            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "\n",
    "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, correct, len(test_loader.dataset),\n",
    "        100. * correct / len(test_loader.dataset)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "q_model = copy.deepcopy(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "kwargs = {'num_workers': 1, 'pin_memory': True}\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data', train=False, transform=transforms.Compose([\n",
    "                       transforms.ToTensor(),\n",
    "                       transforms.Normalize((0.1307,), (0.3081,))\n",
    "                   ])),\n",
    "    batch_size=64, shuffle=True, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test set: Average loss: 0.0313, Accuracy: 9895/10000 (99%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "testQuant(q_model, test_loader, quant=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'conv1': {'max': tensor(179.6297, device='cuda:0'), 'min': tensor(-27.0200, device='cuda:0')}, 'conv2': {'max': tensor(565.8288, device='cuda:0'), 'min': tensor(0., device='cuda:0')}, 'fc1': {'max': tensor(992.2306, device='cuda:0'), 'min': tensor(0., device='cuda:0')}, 'fc2': {'max': tensor(538.2664, device='cuda:0'), 'min': tensor(0., device='cuda:0')}}\n"
     ]
    }
   ],
   "source": [
    "stats = gatherStats(q_model, test_loader)\n",
    "print(stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test set: Average loss: 0.0339, Accuracy: 9889/10000 (99%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "testQuant(q_model, test_loader, quant=True, stats=stats)"
   ]
  },
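  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With per-layer ranges from `gatherStats`, the simulated 8-bit model reaches 9889/10000 (99%), versus 9895/10000 for the float model: post-training quantisation of `conv1`, `conv2` and `fc1` costs only six correct predictions on this test set."
   ]
  }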
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
